Cassava Leaf Disease Classification
Identify the type of disease present in a cassava leaf image
This notebook is a simple training pipeline in TensorFlow for the Cassava Leaf Competition where we are given 21,397 labeled images of cassava leaves classified as 5 different groups (4 diseases and a healthy group) and asked to predict on unseen images of cassava leaves. As with most image classification problems, we can use and experiment with many different forms of augmentation and we can explore transfer learning.
import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as A
import matplotlib.pyplot as plt
import os, gc, cv2, random, warnings
import re, math, sys, json, pprint, pdb
import tensorflow as tf
from tensorflow.keras import backend as K
import tensorflow_hub as hub
from sklearn.model_selection import train_test_split
# Silence Python warnings to keep notebook output readable.
warnings.simplefilter('ignore')
print(f"Using TensorFlow v{tf.__version__}")
#@title Notebook type { run: "auto", display-mode:"form" }
# Global seed; passed to seed_everything() below.
SEED = 16
# NOTE(review): DEBUG is not read anywhere in the visible notebook code.
DEBUG = False #@param {type:"boolean"}
# TRAIN gates the `%%run_if {TRAIN}` cells and decides where the best-model
# weights are loaded from later on.
TRAIN = True #@param {type:"boolean"}
def seed_everything(seed=0):
    """Seed the Python, NumPy and TensorFlow RNGs and set determinism env vars.

    Args:
        seed: integer seed shared by every generator.

    Note:
        PYTHONHASHSEED is only read when an interpreter starts, so setting it
        here affects child processes, not the already-running kernel.
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
    # Environment knobs, applied in this order: hash seed for subprocesses,
    # then TensorFlow's request for deterministic op implementations.
    os.environ.update(PYTHONHASHSEED=str(seed), TF_DETERMINISTIC_OPS='1')
# Detect the runtime by checking whether the IPython shell object's string
# form mentions google.colab; anything else is treated as a Kaggle kernel.
GOOGLE = 'google.colab' in str(get_ipython())
KAGGLE = not GOOGLE
seed_everything(SEED)
print(f"Running on {'Google Colab' if GOOGLE else 'Kaggle Kernel'}!")
#@title {run: "auto", display-mode: "form" }
# Model / input configuration exposed as Colab form fields.
# NOTE(review): BASE_MODEL is only printed below; the model-building cells
# always construct EfficientNetB3 regardless of this value.
BASE_MODEL= 'efficientnet_b3' #@param ["'efficientnet_b3'", "'efficientnet_b4'", "'efficientnet_b2'"] {type:"raw", allow-input: true}
BATCH_SIZE = 32 #@param {type:"integer"}
HEIGHT = 300#@param {type:"number"}
WIDTH = 300#@param {type:"number"}
CHANNELS = 3#@param {type:"number"}
# (height, width, channels) shape used for model inputs and reshapes.
IMG_SIZE = (HEIGHT, WIDTH, CHANNELS)
EPOCHS = 8#@param {type:"number"}
print("Using {} with input size {}".format(BASE_MODEL, IMG_SIZE))
# Load the training label table; later cells use its image_id and label columns.
# NOTE(review): input_path is defined outside this section and is assumed to
# end with a path separator (it is concatenated directly with file names).
df = pd.read_csv(f'{input_path}train.csv')
df.head()
Check how many images are available in the training dataset, and whether each item in the training set is unique.
The distribution of labels is clearly imbalanced, as can be seen in the figure below.
Let's preprocess the dataframe: prepend the directory path to each image ID, store the result in a new `filename` column, and drop the original `image_id` column.
# Turn each bare image_id into a full path, drop the id column, then shuffle
# the rows (pandas draws from NumPy's global RNG, seeded by seed_everything).
df['filename'] = df['image_id'].map(
    lambda image_id: f'{input_path}train_images/{image_id}')
df = (df.drop(columns=['image_id'])
        .sample(frac=1)
        .reset_index(drop=True))
df.head()
Let's find out what labels we have for the 5 categories.
From the bar chart shown earlier, the label 3, Cassava Mosaic Disease (CMD) is the most common one. This imbalance may have to be addressed with a weighted loss function or oversampling. I might try this in a future iteration of this kernel or in a new kernel.
Let's check an example image to see what it looks like
Loading data
After my quick and rough EDA, let's load the PIL Image to a Numpy array, so we can move on to data augmentation.
In fastai, `item_tfms` and `batch_tfms` are defined for the data loader API. The item transforms perform a fairly large crop to 224, and the other standard augmentations (in `aug_transforms`) are applied at the batch level on the GPU. The batch size is set to 32 here.
def count_data_items(filenames):
    """Total the item counts encoded in TFRecord file names.

    Each name must contain ``-<digits>.`` (e.g. ``ld_train00-1338.tfrec``);
    the first such number in every name is summed.

    Args:
        filenames: iterable of TFRecord paths or names.

    Returns:
        The NumPy sum of the extracted counts.
    """
    count_pattern = re.compile(r'-([0-9]*)\.')
    counts = [int(count_pattern.search(name).group(1)) for name in filenames]
    return np.sum(counts)
Split the dataset into training set and validation set
# Collect the training TFRecord shards and display the total number of
# images encoded in their file names (computed once).
filenames = tf.io.gfile.glob(f'{input_path}train_tfrecords/*.tfrec')
count_data_items(filenames)
def transform_shear(image, height, shear):
    """Apply a random shear of up to ``shear`` degrees to one square image.

    The angle is ``shear * U[0, 1)`` degrees, so the sign of ``shear`` picks
    the direction. Every destination pixel is mapped back through a 3x3
    matrix, truncated to an integer index, clipped into the image, and the
    source pixel at that index is gathered (clipped samples repeat edge
    pixels).

    Args:
        image: one image of size [dim, dim, 3], not a batch.
        height: the side length ``dim``; output is reshaped to [height, height, 3].
        shear: maximum shear angle in degrees.

    Returns:
        The sheared image, shape [height, height, 3], same dtype as ``image``.

    NOTE(review): gathered indices only span [0, height), so a larger input is
    not rescaled -- only its top-left height x height region is sampled.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly sheared
    DIM = height
    XDIM = DIM%2 #fix for size 331
    # Random angle in degrees (scaled by U[0, 1)), converted to radians.
    shear = shear * tf.random.uniform([1],dtype='float32')
    shear = math.pi * shear / 180.
    # SHEAR MATRIX
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3])
    # LIST DESTINATION PIXEL INDICES
    # (x, y) are pixel coordinates centred on the image; z = 1 makes them
    # homogeneous so the 3x3 matrix can be applied in one product.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS (here: apply the shear matrix)
    idx2 = K.dot(shear_matrix,tf.cast(idx,dtype='float32'))
    # float -> int32 cast truncates the sampled coordinates.
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so the row/column indices computed below stay inside [0, DIM);
    # XDIM tightens the lower bound when DIM is odd.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    # FIND ORIGIN PIXEL VALUES
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
    return tf.reshape(d,[DIM,DIM,3])
def transform_shift(image, height, h_shift, w_shift):
    """Translate one square image by a random fraction of the given shifts.

    The actual shifts are ``h_shift * U[0, 1)`` and ``w_shift * U[0, 1)``
    pixels, drawn independently. Destination pixels are mapped back through a
    3x3 homogeneous translation matrix, truncated, clipped into the image,
    and gathered (clipped samples repeat edge pixels).

    Args:
        image: one image of size [dim, dim, 3], not a batch.
        height: the side length ``dim``; output is reshaped to [height, height, 3].
        h_shift: maximum shift along the first (row) coordinate, in pixels.
        w_shift: maximum shift along the second (column) coordinate, in pixels.

    Returns:
        The shifted image, shape [height, height, 3], same dtype as ``image``.

    NOTE(review): gathered indices only span [0, height), so a larger input is
    not rescaled -- only its top-left height x height region is sampled.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly shifted
    DIM = height
    XDIM = DIM%2 #fix for size 331
    height_shift = h_shift * tf.random.uniform([1],dtype='float32')
    width_shift = w_shift * tf.random.uniform([1],dtype='float32')
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    # SHIFT MATRIX (translation lives in the third column)
    shift_matrix = tf.reshape(tf.concat([one,zero,height_shift, zero,one,width_shift, zero,zero,one],axis=0),[3,3])
    # LIST DESTINATION PIXEL INDICES
    # (x, y) are pixel coordinates centred on the image; z = 1 makes them homogeneous.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS (here: apply the shift matrix)
    idx2 = K.dot(shift_matrix,tf.cast(idx,dtype='float32'))
    # float -> int32 cast truncates the sampled coordinates.
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so the row/column indices computed below stay inside [0, DIM);
    # XDIM tightens the lower bound when DIM is odd.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    # FIND ORIGIN PIXEL VALUES
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
    return tf.reshape(d,[DIM,DIM,3])
def transform_rotation(image, height, rotation):
    """Rotate one square image by a random angle of up to ``rotation`` degrees.

    The angle is ``rotation * U[0, 1)`` degrees, so the sign of ``rotation``
    picks the direction. Destination pixels are mapped back through a 3x3
    rotation matrix, truncated, clipped into the image, and gathered (clipped
    samples repeat edge pixels).

    Args:
        image: one image of size [dim, dim, 3], not a batch.
        height: the side length ``dim``; output is reshaped to [height, height, 3].
        rotation: maximum rotation angle in degrees.

    Returns:
        The rotated image, shape [height, height, 3], same dtype as ``image``.

    NOTE(review): gathered indices only span [0, height), so a larger input is
    not rescaled -- only its top-left height x height region is sampled.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated
    DIM = height
    XDIM = DIM%2 #fix for size 331
    rotation = rotation * tf.random.uniform([1],dtype='float32')
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.
    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    rotation_matrix = tf.reshape(tf.concat([c1,s1,zero, -s1,c1,zero, zero,zero,one],axis=0),[3,3])
    # LIST DESTINATION PIXEL INDICES
    # (x, y) are pixel coordinates centred on the image; z = 1 makes them homogeneous.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(rotation_matrix,tf.cast(idx,dtype='float32'))
    # float -> int32 cast truncates the sampled coordinates.
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so the row/column indices computed below stay inside [0, DIM);
    # XDIM tightens the lower bound when DIM is odd.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    # FIND ORIGIN PIXEL VALUES
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
    return tf.reshape(d,[DIM,DIM,3])
def random_cutout(image, height, width, channels=3, min_mask_size=(10, 10), max_mask_size=(80, 80), k=1):
    """Zero out ``k`` randomly sized, randomly placed rectangles of ``image``.

    Args:
        image: a single image, expected to have shape [height, width, channels].
        height: image height; each mask is padded up to this size.
        width: image width; each mask is padded up to this size.
        channels: number of image channels.
        min_mask_size: (h, w) inclusive lower bound on each patch size.
        max_mask_size: (h, w) exclusive upper bound on each patch size.
        k: number of patches -- a Python int or a scalar int32 tensor.

    Returns:
        A float32 version of ``image`` with the patches set to 0.

    Raises:
        ValueError: if a mask bound is not strictly smaller than the image.
    """
    if not (height > min_mask_size[0] and height > max_mask_size[0]):
        raise ValueError(
            f"mask heights {min_mask_size[0]}..{max_mask_size[0]} must be < height {height}")
    if not (width > min_mask_size[1] and width > max_mask_size[1]):
        raise ValueError(
            f"mask widths {min_mask_size[1]}..{max_mask_size[1]} must be < width {width}")
    # Cast once before the loop: when k is a tensor, AutoGraph converts this
    # loop into a tf.while_loop, whose loop variables must keep one dtype.
    image = tf.cast(image, tf.float32)
    for _ in range(k):
        mask_height = tf.random.uniform(shape=[], minval=min_mask_size[0], maxval=max_mask_size[0], dtype=tf.int32)
        mask_width = tf.random.uniform(shape=[], minval=min_mask_size[1], maxval=max_mask_size[1], dtype=tf.int32)
        # Random offset of the hole; padding fills the rest of the frame with ones.
        pad_h = height - mask_height
        pad_top = tf.random.uniform(shape=[], minval=0, maxval=pad_h, dtype=tf.int32)
        pad_bottom = pad_h - pad_top
        pad_w = width - mask_width
        pad_left = tf.random.uniform(shape=[], minval=0, maxval=pad_w, dtype=tf.int32)
        pad_right = pad_w - pad_left
        cutout_area = tf.zeros(shape=[mask_height, mask_width, channels], dtype=tf.float32)
        cutout_mask = tf.pad(cutout_area,
                             [[pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
                             constant_values=1)
        image = image * cutout_mask
    return image
def data_augment_cutout(image, min_mask_size=(int(HEIGHT * .1), int(HEIGHT * .1)),
                        max_mask_size=(int(HEIGHT * .125), int(HEIGHT * .125))):
    """Apply a random number of cutout patches to a HEIGHT x WIDTH image.

    A single uniform draw ``p`` picks the patch count (integer ``maxval`` of
    tf.random.uniform is exclusive):
      * p > .85        -> 10 to 14 patches
      * .6 < p <= .85  -> 5 to 9 patches
      * .25 < p <= .6  -> 2 to 4 patches
      * otherwise      -> 1 patch
    Patch sides default to 10%-12.5% of HEIGHT; these defaults are computed
    once, when the function is defined.

    Returns:
        The image with patches zeroed, as returned by random_cutout.
    """
    p_cutout = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    if p_cutout > .85: # 10-14 cut outs
        n_cutout = tf.random.uniform([], 10, 15, dtype=tf.int32)
        image = random_cutout(image, HEIGHT, WIDTH,
                              min_mask_size=min_mask_size, max_mask_size=max_mask_size, k=n_cutout)
    elif p_cutout > .6: # 5-9 cut outs
        n_cutout = tf.random.uniform([], 5, 10, dtype=tf.int32)
        image = random_cutout(image, HEIGHT, WIDTH,
                              min_mask_size=min_mask_size, max_mask_size=max_mask_size, k=n_cutout)
    elif p_cutout > .25: # 2-4 cut outs
        n_cutout = tf.random.uniform([], 2, 5, dtype=tf.int32)
        image = random_cutout(image, HEIGHT, WIDTH,
                              min_mask_size=min_mask_size, max_mask_size=max_mask_size, k=n_cutout)
    else: # 1 cut out
        image = random_cutout(image, HEIGHT, WIDTH,
                              min_mask_size=min_mask_size, max_mask_size=max_mask_size, k=1)
    return image
def data_augment(image, label):
    """Randomly augment one (image, label) pair; the label passes through unchanged.

    All random choices come from uniform draws made up front. Steps, in the
    order applied:
      * shear by up to +/-20 degrees (80% of the time),
      * affine rotation by up to +/-45 degrees (80%),
      * random left/right and up/down flips, then a transpose (25%),
      * rot90 by 270, 180 or 90 degrees (25% each),
      * saturation in [.7, 1.3], contrast in [.8, 1.2], brightness +/-.1
        (60% each, independently),
      * central crop keeping 50-80% (40%) or a random square crop of 60% to
        just under 100% of HEIGHT (30%), followed by a resize to HEIGHT x WIDTH,
      * cutout patches (50%).

    NOTE(review): the shear/rotation helpers assume a HEIGHT x HEIGHT input.
    get_dataset() applies this function before resize_image, i.e. to the
    decoded record at its stored resolution, so those helpers only sample its
    top-left HEIGHT x HEIGHT region -- confirm the stored image size.
    """
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_cutout = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    # Shear
    if p_shear > .2:
        if p_shear > .6:
            image = transform_shear(image, HEIGHT, shear=20.)
        else:
            image = transform_shear(image, HEIGHT, shear=-20.)
    # Rotation (arbitrary angle)
    if p_rotation > .2:
        if p_rotation > .6:
            image = transform_rotation(image, HEIGHT, rotation=45.)
        else:
            image = transform_rotation(image, HEIGHT, rotation=-45.)
    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    if p_spatial > .75:
        image = tf.image.transpose(image)
    # Rotates (multiples of 90 degrees)
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90º
    # Pixel-level transforms
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.7, upper=1.3)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        image = tf.image.random_brightness(image, max_delta=.1)
    # Crops
    if p_crop > .6:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.5)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.6)
        elif p_crop > .7:
            image = tf.image.central_crop(image, central_fraction=.7)
        else:
            image = tf.image.central_crop(image, central_fraction=.8)
    elif p_crop > .3:
        crop_size = tf.random.uniform([], int(HEIGHT*.6), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
    # Bring every branch back to HEIGHT x WIDTH; the cutout helpers below pad
    # their masks to exactly that size. tf.image.resize returns float32.
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])
    if p_cutout > .5:
        image = data_augment_cutout(image)
    return image, label
def decode_image(image_data):
    """Decode JPEG-encoded bytes into a 3-channel RGB image tensor.

    Expects the encoded bytes (e.g. the output of tf.io.read_file or a
    TFRecord field), not a file path.
    """
    return tf.image.decode_jpeg(image_data, channels=3)
def scale_image(image, label):
    """Cast pixels to float32 and divide by 255; the label passes through.

    For 8-bit input this maps pixel values into [0, 1].
    """
    return tf.cast(image, tf.float32) / 255.0, label
def resize_image(image, label):
    """Resize to HEIGHT x WIDTH and pin the static shape to IMG_SIZE.

    The reshape gives downstream layers a fully known input shape. The label
    passes through unchanged.
    """
    resized = tf.image.resize(image, [HEIGHT, WIDTH])
    return tf.reshape(resized, IMG_SIZE), label
def read_tfrecord(example, nclasses=None, labeled=True):
    """Parse one serialized TFRecord example into an (image, label) pair.

    Args:
        example: a serialized tf.train.Example.
        nclasses: one-hot depth for labeled records (required when labeled).
        labeled: if True, read the int64 'target' field and one-hot encode it;
            otherwise return the string 'image_name' field in its place.

    Returns:
        (decoded image, one-hot label or image name).
    """
    scheme = {'image': tf.io.FixedLenFeature([], tf.string)}
    if labeled:
        scheme['target'] = tf.io.FixedLenFeature([], tf.int64)
    else:
        scheme['image_name'] = tf.io.FixedLenFeature([], tf.string)
    parsed = tf.io.parse_single_example(example, scheme)
    image = decode_image(parsed['image'])
    if labeled:
        label = tf.one_hot(tf.cast(parsed['target'], tf.int32), nclasses)
    else:
        label = parsed['image_name']
    return image, label
def get_dataset(filenames, ordered=True, labeled=True,
                aug=None, cached=False, nclasses=NCLASSES,
                repeated=False, bs=BATCH_SIZE, auto=AUTOTUNE):
    """Build a batched tf.data pipeline over TFRecord shards.

    Pipeline: read records -> (cache raw records) -> parse/decode ->
    (augment) -> scale to float -> resize to IMG_SIZE -> (shuffle) ->
    (repeat) -> batch -> prefetch.

    Args:
        filenames: TFRecord paths (or glob patterns when ``ordered`` is False).
        ordered: keep a deterministic element order and skip shuffling. When
            False, file order is randomized, non-deterministic parallelism is
            allowed, and elements are shuffled with a 2048-element buffer.
        labeled: parse one-hot 'target' labels, otherwise 'image_name'.
        aug: optional per-example ``(image, label) -> (image, label)`` map,
            applied before scaling and resizing.
        cached: cache the raw serialized records (after the first pass). This
            happens before parsing, so augmentation stays random each epoch
            and the cache is filled before any repeat.
        nclasses: one-hot depth for labels.
        repeated: repeat the dataset indefinitely.
        bs: batch size.
        auto: parallelism / prefetch setting (typically tf.data.AUTOTUNE).

    NOTE(review): the NCLASSES / AUTOTUNE defaults are globals defined outside
    this section.
    """
    if ordered:
        ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=auto)
    else:
        ds = tf.data.Dataset.list_files(filenames)
        ds = ds.interleave(tf.data.TFRecordDataset, num_parallel_calls=auto)
        # Only give up deterministic output order when the caller doesn't need it.
        options = tf.data.Options()
        options.experimental_deterministic = False
        ds = ds.with_options(options)
    if cached:
        ds = ds.cache()
    ds = ds.map(lambda x: read_tfrecord(x, labeled=labeled, nclasses=nclasses),
                num_parallel_calls=auto)
    if aug:
        ds = ds.map(aug, num_parallel_calls=auto)
    ds = ds.map(scale_image, num_parallel_calls=auto)
    ds = ds.map(resize_image, num_parallel_calls=auto)
    if not ordered:
        # Dataset transformations return new datasets; the result must be kept.
        ds = ds.shuffle(2048)
    if repeated:
        ds = ds.repeat()
    ds = ds.batch(bs)
    ds = ds.prefetch(auto)
    return ds
# Ordered (unshuffled), unaugmented, batched dataset over all training shards,
# used by the preview cells below.
train_ds = get_dataset(filenames)
def show_images(ds):
    """Plot the first 9 (image, one-hot label) pairs of an unbatched dataset.

    Images may be uint8 in [0, 255] or floats. Floats whose maximum is <= 1
    (as produced by get_dataset's scale_image) are rescaled to [0, 255]
    before display. Titles show the argmax of the one-hot label.

    Args:
        ds: iterable exposing ``take(n)`` that yields (image, label) pairs.
    """
    _, axs = plt.subplots(3, 3, figsize=(16, 16))
    # Hide every axis up front so unused panels (ds shorter than 9) stay blank.
    for ax in axs.flatten():
        ax.axis('off')
    for (x, y), ax in zip(ds.take(9), axs.flatten()):
        img = np.asarray(x)
        # Casting [0, 1] floats straight to uint8 would render black images.
        if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
            img = img * 255.0
        ax.imshow(np.clip(img, 0, 255).astype(np.uint8))
        ax.set_title(np.argmax(y))
# Re-batch into groups of 20 and preview the first group.
# NOTE(review): display_batch_of_images is defined further down in this file;
# when the notebook is run top to bottom this raises NameError unless that
# cell has already been executed.
train_iter = iter(train_ds.unbatch().batch(20))
display_batch_of_images(next(train_iter))
Show some training images
# Keep NumPy array printouts short: summarize arrays above 15 elements and
# wrap lines at 80 characters.
np.set_printoptions(linewidth=80, threshold=15)
def batch_to_numpy_images_and_labels(data):
    """Convert an (images, labels) tensor batch into NumPy objects.

    Labels with object dtype are byte-string image IDs (unlabeled test
    data); in that case a list with one None per image is returned instead.

    Returns:
        (numpy images, numpy labels or list of None).
    """
    images, labels = data
    numpy_images, numpy_labels = images.numpy(), labels.numpy()
    if numpy_labels.dtype == object:
        numpy_labels = [None] * len(numpy_images)
    return numpy_images, numpy_labels
def title_from_label_and_target(label, correct_label):
    """Build a plot title for a predicted class index and flag correctness.

    Returns ``(CLASSES[label], True)`` when ``correct_label`` is None.
    Otherwise returns ``("<pred> [OK]", True)`` or
    ``("<pred> [NO→<truth>]", False)``.

    NOTE(review): this file defines title_from_label_and_target a second time
    further down; at module scope the later definition wins.
    """
    predicted_name = CLASSES[label]
    if correct_label is None:
        return predicted_name, True
    correct = (label == correct_label)
    if correct:
        verdict = 'OK'
    else:
        verdict = 'NO' + u"\u2192" + f"{CLASSES[correct_label]}"
    return f"{predicted_name} [{verdict}]", correct
def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw one image into the (rows, cols, index) subplot; return the next slot.

    Empty titles are skipped. A red title marks a wrong prediction and uses a
    slightly smaller font.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title):
        if red:
            fontsize, colour = int(titlesize / 1.2), 'red'
        else:
            fontsize, colour = int(titlesize), 'black'
        plt.title(title, fontsize=fontsize, color=colour,
                  fontdict={'verticalalignment': 'center'}, pad=int(titlesize / 1.5))
    rows, cols, index = subplot
    return (rows, cols, index + 1)
def display_batch_of_images(databatch, predictions=None):
    """Plot a square-ish grid of images, optionally titled with labels/predictions.

    This will work with:
        display_batch_of_images(images)
        display_batch_of_images(images, predictions)
        display_batch_of_images((images, labels))
        display_batch_of_images((images, labels), predictions)

    Labels may be one-hot (argmax is taken), integer class indices, or image-ID
    strings (treated as "no labels"). Images that do not fit the
    rows x cols grid (rows = floor(sqrt(n))) are dropped.

    NOTE(review): relies on CLASSES and title_from_label_and_target defined
    elsewhere; this file defines title_from_label_and_target twice and the
    later module-level definition is the one called here.
    """
    # data
    if isinstance(databatch, tuple):
        images, labels = batch_to_numpy_images_and_labels(databatch)
    else:
        images, labels = np.asarray(databatch), None
    if labels is None:
        labels = [None for _ in enumerate(images)]
    elif isinstance(labels, np.ndarray) and labels.ndim > 1:
        # One-hot -> class index (only for real label arrays, not ID placeholders).
        labels = np.argmax(labels, axis=-1)
    if len(images) == 0:
        return
    has_labels = any(label is not None for label in labels)

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows, cols, 1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE / cols * rows))
    else:
        plt.figure(figsize=(FIGSIZE / rows * cols, FIGSIZE))

    # display
    for i, (image, label) in enumerate(zip(images[:rows * cols], labels[:rows * cols])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE * SPACING / max(rows, cols) * 40 + 3  # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    # layout: pack tightly when there are no titles at all
    plt.tight_layout()
    if not has_labels and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()
# Visualize model predictions
def dataset_to_numpy_util(dataset, N):
    """Return the first N (image, label) elements of a batched dataset as NumPy.

    The dataset is unbatched and re-batched to size N, and only the first
    batch is materialized.

    Raises:
        ValueError: if the dataset yields no elements.
    """
    first_batch = next(iter(dataset.unbatch().batch(N)), None)
    if first_batch is None:
        raise ValueError("dataset_to_numpy_util: dataset is empty")
    images, labels = first_batch
    return images.numpy(), labels.numpy()
def title_from_label_and_target(label, correct_label):
    """Title a prediction for the evaluation grid and flag whether it is right.

    Args:
        label: predicted class scores/probabilities; argmax over the last axis
            gives the predicted index.
        correct_label: ground-truth class index.

    Returns:
        ``("<pred> [True]", True)`` or
        ``("<pred> [False, should be <truth>]", False)``.
    """
    label = np.argmax(label, axis=-1)
    correct = (label == correct_label)
    return "{} [{}{}{}]".format(label, str(correct), ', should be ' if not correct else '',
                                correct_label if not correct else ''), correct
def display_one_flower_eval(image, title, subplot, red=False):
    """Draw one image into a 3-digit subplot code (e.g. 331) and return the next code.

    The title is red for a wrong prediction and black otherwise.
    """
    plt.subplot(subplot)
    plt.axis('off')
    plt.imshow(image)
    title_colour = 'red' if red else 'black'
    plt.title(title, fontsize=14, color=title_colour)
    return subplot + 1
def display_9_images_with_predictions(images, predictions, labels):
    """Show up to the first 9 images in a 3x3 grid, titled with predictions.

    Wrong predictions get a red title.
    """
    plt.figure(figsize=(13, 13))
    slot = 331
    # zip with range(9) stops after the 9th image without reading a 10th.
    for i, image in zip(range(9), images):
        title, correct = title_from_label_and_target(predictions[i], labels[i])
        slot = display_one_flower_eval(image, title, slot, not correct)
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
# Model evaluation
def plot_metrics(history):
    """Plot train/validation loss and accuracy curves from a Keras history dict.

    Dashed vertical lines mark the best epoch: the minimum for loss and the
    maximum for accuracy. Train markers use the default colour and validation
    markers are orange.
    """
    fig, axes = plt.subplots(2, 1, sharex='col', figsize=(20, 8))
    axes = axes.flatten()
    panels = (
        (axes[0], 'loss', 'Loss', np.argmin),
        (axes[1], 'accuracy', 'Accuracy', np.argmax),
    )
    for ax, key, title, best_epoch in panels:
        ax.plot(history[key], label=f'Train {key}')
        ax.plot(history[f'val_{key}'], label=f'Validation {key}')
        ax.legend(loc='best', fontsize=16)
        ax.set_title(title)
        ax.axvline(best_epoch(history[key]), linestyle='dashed')
        ax.axvline(best_epoch(history[f'val_{key}']), linestyle='dashed', color='orange')
    plt.xlabel('Epochs', fontsize=16)
    sns.despine()
    plt.show()
Show some validation images
# Keras preprocessing pipeline, used for this preview and inside build_model().
data_augmentation = tf.keras.Sequential(
    [
        tf.keras.layers.experimental.preprocessing.RandomCrop(HEIGHT, WIDTH),
        tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
        tf.keras.layers.experimental.preprocessing.RandomRotation(0.25),
        tf.keras.layers.experimental.preprocessing.RandomZoom((-0.2, 0)),
        tf.keras.layers.experimental.preprocessing.RandomContrast((0.2,0.2))
    ]
)
# training=True: outside of fit(), Keras' random preprocessing layers run in
# inference mode and do not randomize their input.
func = lambda x, y: (data_augmentation(x, training=True), y)
# train_ds is already batched by get_dataset(); batching it again would give
# 5-D tensors, so take a single existing batch instead.
x = (train_ds
     .take(1)
     .map(func, num_parallel_calls=AUTOTUNE))
show_images(x.unbatch())
%%run_if {GOOGLE}
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.applications import VGG16
def build_model(base_model, num_class):
    """Wrap a pooled backbone with augmentation, dropout and a softmax head.

    Pipeline: Input(IMG_SIZE) -> data_augmentation -> base_model ->
    Dropout(0.4) -> Dense(num_class, softmax, name="pred").

    Args:
        base_model: Keras model producing a feature vector per image.
        num_class: number of output classes.

    Returns:
        An uncompiled tf.keras Model.
    """
    inputs = tf.keras.layers.Input(shape=IMG_SIZE)
    features = base_model(data_augmentation(inputs))
    features = tf.keras.layers.Dropout(0.4)(features)
    head = tf.keras.layers.Dense(num_class, activation="softmax", name="pred")
    return tf.keras.models.Model(inputs=inputs, outputs=head(features))
# Fully trainable EfficientNetB3 backbone with global average pooling.
# ImageNet weights are only downloaded when training.
# NOTE(review): id2label is defined outside this section; BASE_MODEL is not
# consulted here.
efficientnet = EfficientNetB3(
    include_top=False,
    weights='imagenet' if TRAIN else None,
    input_shape=IMG_SIZE,
    pooling='avg',
)
efficientnet.trainable = True
model = build_model(base_model=efficientnet, num_class=len(id2label))
model.summary()
The 3rd layer of EfficientNet is a Normalization layer, which can be adapted to our new dataset instead of ImageNet. Be patient with this one: it takes a while, since it goes through the entire training set.
%%run_if {GOOGLE and TRAIN}
# Adapt the backbone's Normalization layer to this dataset once and cache the
# resulting model weights on disk; later runs simply reload them.
# NOTE(review): adapt_ds_batch is defined outside this section.
if not os.path.exists("000_normalization.h5"):
    model.get_layer('efficientnetb3').get_layer('normalization').adapt(adapt_ds_batch)
    model.save_weights("000_normalization.h5")
else:
    model.load_weights("000_normalization.h5")
I use the CosineDecayRestarts schedule implemented in tf.keras, as it seemed promising and I struggled to find the right settings (if there are any) for ReduceLROnPlateau.
%%run_if {TRAIN}
#@title { run: "auto", display-mode: "form" }
STEPS = math.ceil(len(train_df) / BATCH_SIZE) * EPOCHS
LR_START = 9e-3 #@param {type: "number"}
LR_START *= strategy.num_replicas_in_sync
LR_MIN = 3e-4 #@param {type: "number"}
N_RESTARTS = 5#@param {type: "number"}
T_MUL = 2.0 #@param {type: "number"}
M_MUL = 1#@param {type: "number"}
STEPS_START = math.ceil((T_MUL-1)/(T_MUL**(N_RESTARTS+1)-1) * STEPS)
schedule = tf.keras.experimental.CosineDecayRestarts(
first_decay_steps=STEPS_START,
initial_learning_rate=LR_START,
alpha=LR_MIN,
m_mul=M_MUL,
t_mul=T_MUL)
x = [i for i in range(STEPS)]
y = [schedule(s) for s in range(STEPS)]
_,ax = plt.subplots(1,1,figsize=(8,5),facecolor='#F0F0F0')
ax.plot(x, y)
ax.set_facecolor('#F8F8F8')
ax.set_xlabel('iteration')
ax.set_ylabel('learning rate')
print('{:d} total epochs and {:d} steps per epoch'
.format(EPOCHS, STEPS // EPOCHS))
print(schedule.get_config())
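The LRFinder callback below sets the learning rate at every batch instead of relying on the LearningRateScheduler that TensorFlow provides. LearningRateScheduler updates the learning rate in on_epoch_begin, whereas it makes more sense to do so in on_batch_begin or on_batch_end.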
%%run_if {GOOGLE and TRAIN}
from tensorflow.keras.callbacks import Callback
class LRFinder(Callback):
    """`Callback` that exponentially adjusts the learning rate after
    each training batch between `start_lr` and `end_lr` for a maximum number
    of batches: `max_step`. The loss and learning rate are recorded at each
    step allowing visually finding a good learning rate as
    https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html suggested.

    The recorded loss is a bias-corrected exponential moving average with
    factor `smoothing`. Training stops once the smoothed loss exceeds 4x the
    best smoothed loss seen so far, becomes NaN, or `max_steps` is reached.
    """

    def __init__(self, start_lr: float = 1e-7, end_lr: float = 10,
                 max_steps: int = 100, smoothing=0.9):
        super(LRFinder, self).__init__()
        self.start_lr, self.end_lr = start_lr, end_lr
        self.max_steps = max_steps
        self.smoothing = smoothing
        self.step, self.best_loss, self.avg_loss, self.lr = 0, 0, 0, 0
        self.lrs, self.losses = [], []

    def on_train_begin(self, logs=None):
        """Reset all recorded state so the finder can be reused."""
        self.step, self.best_loss, self.avg_loss, self.lr = 0, 0, 0, 0
        self.lrs, self.losses = [], []

    def on_train_batch_begin(self, batch, logs=None):
        """Set the optimizer LR for the upcoming batch."""
        self.lr = self.exp_annealing(self.step)
        tf.keras.backend.set_value(self.model.optimizer.lr, self.lr)

    def on_train_batch_end(self, batch, logs=None):
        """Record the smoothed loss for this LR and decide whether to stop."""
        logs = logs or {}
        loss = logs.get('loss')
        step = self.step
        if loss is not None:
            self.avg_loss = self.smoothing * self.avg_loss + (1 - self.smoothing) * loss
            # Bias correction: avg_loss starts at 0, so early averages are too small.
            smooth_loss = self.avg_loss / (1 - self.smoothing ** (self.step + 1))
            self.losses.append(smooth_loss)
            self.lrs.append(self.lr)
            # Track the best *smoothed* loss so the divergence test below
            # compares like with like (as in the referenced article).
            if step == 0 or smooth_loss < self.best_loss:
                self.best_loss = smooth_loss
            if smooth_loss > 4 * self.best_loss or np.isnan(smooth_loss):
                self.model.stop_training = True
        if step == self.max_steps:
            self.model.stop_training = True
        self.step += 1

    def exp_annealing(self, step):
        """LR at `step`, interpolated geometrically from start_lr to end_lr over max_steps."""
        return self.start_lr * (self.end_lr / self.start_lr) ** (step * 1. / self.max_steps)

    def plot(self, skip_end=None):
        """Plot smoothed loss vs. LR (log scale), dropping the last `skip_end` points (default 5)."""
        lrs = self.lrs[:-skip_end] if skip_end else self.lrs[:-5]
        losses = self.losses[:-skip_end] if skip_end else self.losses[:-5]
        fig, ax = plt.subplots(1, 1, facecolor="#F0F0F0")
        ax.set_ylabel('Loss')
        ax.set_xlabel('Learning Rate')
        ax.set_xscale('log')
        ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))
        ax.plot(lrs, losses)
%%run_if {GOOGLE and TRAIN}
# LR-range test: one epoch in which LRFinder raises the learning rate every
# batch and records the smoothed loss.
# NOTE(review): read_tfrecord produces one-hot labels, which
# sparse_categorical_crossentropy does not accept; confirm that
# train_ds_batch (defined outside this section) yields integer labels.
model.compile(loss="sparse_categorical_crossentropy",
              optimizer="adam",
              metrics=["accuracy"])
lr_finder = LRFinder()
_ = model.fit(train_ds_batch, epochs=1, callbacks=[lr_finder])
%%run_if {GOOGLE and TRAIN}
# Plot smoothed loss against learning rate, leaving out the last 20 recorded points.
lr_finder.plot(skip_end=20)
As can be observed from the curve, we can pinpoint the lr_max to be 9e-3 and the lr_min to be 3e-4. Let's feed these hyperparams back to the optimizer schedule and retrain the model.
Before retraining, don't forget to reset the model, so that training starts from the weights saved in 000_normalization.h5 rather than from the state left after the one epoch run by the lr_finder.
(In a tflearner-style library, this would be implemented as a `.reset` method of a learner class.)
%%run_if {GOOGLE and TRAIN}
# Rebuild a fresh model and restore the adapted-normalization weights, so
# training does not start from the state left behind by the LR finder.
efficientnet = EfficientNetB3(
    weights = 'imagenet',
    include_top = False,
    input_shape = IMG_SIZE,
    pooling='avg')
efficientnet.trainable = True
model = build_model(base_model=efficientnet, num_class=len(id2label))
model.load_weights("000_normalization.h5")
%%run_if {TRAIN}
# Save weights to 001_best_model.h5 only when validation loss improves.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath='001_best_model.h5',
        monitor='val_loss',
        save_best_only=True),
]
%%run_if {TRAIN}
# Compile with Adam driven by the cosine-restarts schedule.
# NOTE(review): read_tfrecord produces one-hot labels, which
# sparse_categorical_crossentropy does not accept; confirm the label format
# of train_ds_batch / valid_ds_batch (defined outside this section).
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=tf.keras.optimizers.Adam(schedule),
              metrics=["accuracy"])
%%run_if {TRAIN}
# Train for EPOCHS epochs with validation after each epoch; the checkpoint
# callback keeps the best weights.
history = model.fit(train_ds_batch,
                    epochs = EPOCHS,
                    validation_data=valid_ds_batch,
                    callbacks=callbacks)
def show_history(history):
    """Plot loss-like and accuracy-like series of a Keras history dict side by side.

    Every key containing 'loss' (resp. 'accuracy') is drawn on the left
    (resp. right) panel in dict order. The legend assumes train then
    validation.
    """
    topics = ['loss', 'accuracy']
    _, axs = plt.subplots(1, 2, figsize=(15, 6), facecolor='#F0F0F0')
    for topic, ax in zip(topics, axs.flatten()):
        for key, series in history.items():
            if topic in key:
                ax.plot(series)
        ax.set_facecolor('#F8F8F8')
        ax.set_title(f'{topic} over epochs')
        ax.set_xlabel('epoch')
        ax.set_ylabel(topic)
        ax.legend(['train', 'valid'], loc='best')
%%run_if {TRAIN}
# Plot training/validation loss and accuracy curves.
show_history(history.history)
We load the best weights kept from the training phase. To check how our model is performing, we then run predictions over the validation set, which can highlight any classes that are consistently misclassified.
# Restore the best checkpoint: the file saved in this session when training,
# otherwise the copy shipped in the models input dataset.
model.load_weights('{}001_best_model.h5'.format(
    '' if TRAIN else '../input/cassava-leaf-disease-classification-models/'))
x = train_df.sample(1).filename.values[0]
# decode_image expects encoded JPEG bytes, not a path, so read the file first.
img = decode_image(tf.io.read_file(x))
%%time
# Show four random IMG_SIZE crops of the sampled image.
imgs = [tf.image.random_crop(img, size=IMG_SIZE) for _ in range(4)]
_,axs = plt.subplots(1,4,figsize=(16,4))
for (x, ax) in zip(imgs, axs.flatten()):
    ax.imshow(x.numpy().astype(np.uint8))
    ax.axis('off')
I apply some very basic test time augmentation to every local image extracted from the original 600-by-800 images. We know we can do some fancy augmentation with albumentations but I wanted to do that exclusively with Keras preprocessing layers to keep the cleanest pipeline possible.
# Test-time augmentation pipeline: random crop to HEIGHT x WIDTH, flips,
# zoom and contrast.
tta = tf.keras.Sequential(
    [
        tf.keras.layers.experimental.preprocessing.RandomCrop(HEIGHT, WIDTH),
        tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
        tf.keras.layers.experimental.preprocessing.RandomZoom((-0.2, 0.2)),
        tf.keras.layers.experimental.preprocessing.RandomContrast((0.2,0.2))
    ]
)


def predict_tta(filename, num_tta=4):
    """Predict the class index of one image file by summing TTA predictions.

    Args:
        filename: path to a JPEG image.
        num_tta: number of augmented copies to score.

    Returns:
        The argmax of the summed per-copy model outputs.

    NOTE(review): images are fed as raw 0-255 pixels while get_dataset()
    scales training images to [0, 1]; confirm which scale the trained model
    expects.
    """
    # decode_image expects encoded JPEG bytes, not a path.
    img = decode_image(tf.io.read_file(filename))
    img = tf.expand_dims(img, 0)
    # training=True: otherwise the random layers run in inference mode and
    # every TTA copy would be identical.
    imgs = tf.concat([tta(img, training=True) for _ in range(num_tta)], 0)
    preds = model.predict(imgs)
    return preds.sum(0).argmax()


pred = predict_tta(df.sample(1).filename.values[0])
print(pred)
from tqdm import tqdm

# Predict every validation image with 4-way TTA (with a progress bar), then
# compare predictions against the true labels in a confusion matrix.
# NOTE(review): valid_df and id2label are defined outside this section.
preds = []
for filename in tqdm(valid_df.filename, total=len(valid_df)):
    preds.append(predict_tta(filename, num_tta=4))

cm = tf.math.confusion_matrix(valid_df.label.values, np.array(preds))
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    xticklabels=id2label.values(),
    yticklabels=id2label.values(),
    annot=True,
    fmt='g',
    cmap="Blues",
)
plt.xlabel('Prediction')
plt.ylabel('Label')
plt.show()
# Build the submission: one TTA prediction per file in the test image folder.
# os.path.join works whether or not input_path ends with a separator.
test_folder = os.path.join(input_path, 'test_images', '')
test_images = os.listdir(test_folder)
# A dict of columns fixes the CSV column order: image_id first, then label.
submission_df = pd.DataFrame({
    "image_id": test_images,
    "label": [predict_tta(test_folder + image_id) for image_id in test_images],
})
submission_df
submission_df.to_csv("submission.csv", index=False)